#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2015-2020 Daniel Rodriguez
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import backtrader as bt
import os
from collections import OrderedDict
from backtrader.utils import AutoOrderedDict
from backtrader.utils.datehelper import *
from backtrader import TimeFrame
import pandas as pd

# Analyzers exported by ``from <module> import *``.
__all__ = ['DrawDown', 'TimeDrawDown', 'TimeDrawDownHistory']

# Format string passed to date2str() when drawdown dates are recorded in the
# drawdown history rows.
DATE_FORMAT = "%Y-%m-%d"
# Period key used by DrawDownAnalysis when its dt_mode is neither Months nor
# Years, i.e. the whole run is treated as one single period.
NO_TIMEFRAME_DT_KEY = -1


class DrawDownPeriod:
    """Track the drawdowns of a value series within a single period.

    Observations are fed one at a time through ``update(dt, value)``. When
    the value drops below its running peak a drawdown is opened. When the
    value gets back to (or above) that peak the drawdown is closed: it is
    appended to ``history`` and folded into the summary statistics returned
    by ``get_analysis``. ``stop`` closes a drawdown that is still open at the
    end of the series.

    ``dt`` values are subtracted from each other and ``.days`` is read from
    the result, so they are expected to be ``date``/``datetime`` objects.
    """

    def __init__(self):
        self.max_value = 0  # running peak of the series (never reset)

        # Summary statistics over every closed drawdown of this period
        self.mdd = 0  # deepest drawdown ratio seen
        self.mddlen = 0  # ddlen of that deepest drawdown
        self.mddlen2 = 0  # ddlen2 of that deepest drawdown
        self.mddvalue = 0  # monetary size of that deepest drawdown
        self.maxmddlen = 0  # longest ddlen among all drawdowns
        self.maxmddlen2 = 0  # longest ddlen2 among all drawdowns

        self.curr_dt = None  # dt of the latest update()
        self.curr_value = None  # value of the latest update()

        self.history = []  # one row per closed drawdown, see add_history()

        # State of the drawdown currently being tracked
        self.reset()

    def reset(self):
        """Clear the state of the current drawdown.

        The running peak and the summary statistics are deliberately kept.
        """
        self.dd = 0.0  # drawdown ratio (peak - low) / peak
        self.dd_value = 0.0  # drawdown amount (peak - low)
        self.ddlen = 0  # days from drawdown start to the low, counted inclusively
        self.ddlen2 = 0  # days from drawdown start to its close
        self.dd_start = 0  # dt of the first observation below the peak
        self.dd_low = 0  # dt of the lowest observation
        self.dd_close = 0  # dt at which the drawdown was closed
        self.dd_start_val = 0  # value at dd_start
        self.dd_low_val = 0  # value at dd_low
        self.dd_close_val = 0  # value at dd_close
        self.dd_status = False  # True while a drawdown is open

    def add_history(self):
        """Append the current drawdown to ``history`` as one row.

        Row layout: start date, low date, close date, start value, low value,
        close value, drawdown ratio (rounded to 4 decimals), drawdown amount,
        ddlen, ddlen2. Dates are converted with ``date2str(.., DATE_FORMAT)``.
        """
        self.history.append([
            date2str(self.dd_start, DATE_FORMAT),
            date2str(self.dd_low, DATE_FORMAT),
            date2str(self.dd_close, DATE_FORMAT),
            self.dd_start_val,
            self.dd_low_val,
            self.dd_close_val,
            round(self.dd, 4),
            self.dd_value,
            self.ddlen,
            self.ddlen2
        ])

    def update(self, dt, value):
        """Feed one observation ``value`` taken at ``dt``."""
        self.curr_dt = dt
        self.curr_value = value
        self.max_value = max(self.max_value, value)

        if self.max_value <= 0:
            # No positive peak yet: a drawdown ratio is undefined (and the
            # division below would raise ZeroDivisionError), so skip.
            return

        # Was not in a drawdown and the value fell below the peak: open one.
        if not self.dd_status and value < self.max_value:
            self.dd_status = True
            self.dd_start_val = value
            self.dd_start = dt

        # Was in a drawdown and the value is back at the peak: close it.
        if self.dd_status and value >= self.max_value:
            self.close()

        _dd = (self.max_value - value) / self.max_value

        # Deeper than anything seen so far in this drawdown: new low.
        if _dd > self.dd:
            self.dd_low = dt
            self.dd_low_val = value
            self.dd = _dd
            self.dd_value = self.max_value - value

    def close(self):
        """Close the current drawdown at the latest observation and record it.

        When called from ``stop`` for a drawdown that never recovered, the
        "close" is simply the last observation.
        """
        self.dd_close = self.curr_dt
        self.dd_close_val = self.curr_value
        self.ddlen = (self.dd_low - self.dd_start).days + 1
        self.ddlen2 = (self.dd_close - self.dd_start).days
        self.add_history()

        # The deepest drawdown so far becomes the reported max drawdown.
        if self.dd > self.mdd:
            self.mdd = self.dd
            self.mddlen = self.ddlen
            self.mddlen2 = self.ddlen2
            self.mddvalue = self.dd_value

        # Longest durations are taken over every drawdown, not only over the
        # ones that happened to be the deepest when they closed.
        self.maxmddlen = max(self.maxmddlen, self.ddlen)
        self.maxmddlen2 = max(self.maxmddlen2, self.ddlen2)
        self.reset()

    def stop(self):
        """Close (and record) a drawdown still open at the end of the series."""
        if self.dd_status:
            self.close()

    def get_history(self):
        """Return the list of recorded drawdown rows (see ``add_history``)."""
        return self.history

    def get_analysis(self):
        """Return the summary statistics as an ordered mapping."""
        # Built from pairs so the key order is kept on every Python version.
        return OrderedDict([
            ("mdd", self.mdd),
            ("mddvalue", self.mddvalue),
            ("mddlen", self.mddlen),
            ("mddlen2", self.mddlen2),
            ("maxmddlen", self.maxmddlen),
            ("maxmddlen2", self.maxmddlen2),
        ])


class DrawDownAnalysis:
    """Split a value series into periods and track drawdowns per period.

    Each period gets its own ``DrawDownPeriod`` (so the running peak restarts
    at the beginning of every period). After ``stop``, ``get_analysis`` maps
    each period key to that period's summary statistics and ``get_history``
    maps it to that period's drawdown rows.
    """

    def __init__(self, dt_mode=bt.TimeFrame.NoTimeFrame):
        # Period grouping: TimeFrame.Years -> key is the year,
        # TimeFrame.Months -> key is year * 100 + month, anything else ->
        # the whole run is one period keyed NO_TIMEFRAME_DT_KEY.
        self.dt_mode = dt_mode
        self.pre_dtkey = None  # key of the period currently being tracked
        self.period_mdd = DrawDownPeriod()
        self.rets = OrderedDict()  # period key -> summary statistics
        self.history = OrderedDict()  # period key -> drawdown rows

    def update(self, dt, value):
        """Feed one observation; finalize the previous period on a key change."""
        dtkey = self.curr_dt_key(dt, self.dt_mode)
        if self.pre_dtkey is not None and dtkey != self.pre_dtkey:
            # A new period started: finalize the previous one, start fresh.
            self._close()
            self.period_mdd = DrawDownPeriod()
        self.period_mdd.update(dt, value)
        self.pre_dtkey = dtkey

    def _close(self):
        """Close the current period and store its results under its key."""
        self.period_mdd.stop()
        self.rets[self.pre_dtkey] = self.period_mdd.get_analysis()
        self.history[self.pre_dtkey] = self.period_mdd.get_history()

    def stop(self):
        """Finalize the last period, if any observation was ever fed."""
        if self.pre_dtkey is None:
            # Nothing was recorded: do not publish an empty period keyed None.
            return
        self._close()

    def get_history(self):
        return self.history

    def get_analysis(self):
        return self.rets

    def curr_dt_key(self, dt, dt_mode):
        """Return the period key of ``dt`` for the given ``dt_mode``."""
        if dt_mode == bt.TimeFrame.Months:
            dt_key = dt.year * 100 + dt.month
        elif dt_mode == bt.TimeFrame.Years:
            dt_key = dt.year
        else:
            dt_key = NO_TIMEFRAME_DT_KEY
        return dt_key


class DrawDown(bt.Analyzer):
    '''This analyzer calculates trading system drawdowns stats such as drawdown
    values in %s and in dollars, max drawdown in %s and in dollars, drawdown
    length and drawdown max length

    Params:

      - ``fund`` (default: ``None``)

        If ``None`` the actual mode of the broker (fundmode - True/False) will
        be autodetected to decide if the returns are based on the total net
        asset value or on the fund value. See ``set_fundmode`` in the broker
        documentation

        Set it to ``True`` or ``False`` for a specific behavior

    Methods:

      - ``get_analysis``

        Returns a dictionary (with . notation support and subdictionaries) with
        drawdown stats as values, the following keys/attributes are available:

        - ``drawdown`` - latest drawdown in percent (``12.5`` means 12.5%)
        - ``moneydown`` - latest drawdown in monetary units
        - ``len`` - number of consecutive ``next`` calls spent in drawdown

        - ``max.drawdown`` - max drawdown in percent
        - ``max.moneydown`` - max drawdown in monetary units
        - ``max.len`` - longest ``len`` seen

        ``max.drawdown`` and ``max.moneydown`` are maximised independently,
        so they may come from different drawdowns.
    '''

    params = (
        ('fund', None),
    )

    def start(self):
        super(DrawDown, self).start()
        if self.p.fund is None:
            self._fundmode = self.strategy.broker.fundmode
        else:
            self._fundmode = self.p.fund

    def create_analysis(self):
        self.rets = AutoOrderedDict()  # dict with . notation

        self.rets.len = 0
        self.rets.drawdown = 0.0
        self.rets.moneydown = 0.0

        self.rets.max.len = 0  # a count, like ``len``
        self.rets.max.drawdown = 0.0
        self.rets.max.moneydown = 0.0

        self._maxvalue = float('-inf')  # any value will outdo it

    def stop(self):
        self.rets._close()  # . notation cannot create more keys

    # notify_fund is called first, then next
    def notify_fund(self, cash, value, fundvalue, shares):
        if not self._fundmode:
            self._value = value  # record current value
            self._maxvalue = max(self._maxvalue, value)  # update peak value
        else:
            self._value = fundvalue  # record current value
            self._maxvalue = max(self._maxvalue, fundvalue)  # update peak

    def next(self):
        r = self.rets
        maxvalue = self._maxvalue

        # calculate current drawdown values
        r.moneydown = moneydown = maxvalue - self._value
        # A zero peak would raise ZeroDivisionError: report 0% instead.
        drawdown = 100.0 * moneydown / maxvalue if maxvalue else 0.0
        r.drawdown = drawdown

        # maximum drawdown values (each maximised on its own)
        r.max.moneydown = max(r.max.moneydown, moneydown)
        r.max.drawdown = max(r.max.drawdown, drawdown)

        # consecutive calls in drawdown; any zero drawdown resets the count
        r.len = r.len + 1 if drawdown else 0
        r.max.len = max(r.max.len, r.len)


class TimeDrawDown(bt.TimeFrameAnalyzerBase):
    """Compute drawdowns within a specified time period.

    The broker value (or the fund value in fund mode) is sampled every time
    ``on_dt_over`` is invoked, i.e. once the current timeframe period is over.
    Each sample is fed to a ``DrawDownAnalysis`` that groups the drawdowns by
    ``dt_mode``: ``bt.TimeFrame.Years`` (key: year), ``bt.TimeFrame.Months``
    (key: year * 100 + month), anything else treats the whole run as one
    period.

    Params:

      - ``fund`` (default: ``None``): ``None`` autodetects the broker's fund
        mode; ``True``/``False`` forces fund value / total value.
      - ``dt_mode`` (default: ``bt.TimeFrame.NoTimeFrame``): period grouping
        of the drawdown statistics.

    After ``stop`` the analysis maps each period key to that period's
    statistics (``mdd``, ``mddvalue``, ``mddlen``, ``mddlen2``,
    ``maxmddlen``, ``maxmddlen2``).
    """

    params = (
        ('fund', None),
        ('dt_mode', bt.TimeFrame.NoTimeFrame),
    )

    def __init__(self):
        super(TimeDrawDown, self).__init__()
        self.dda = DrawDownAnalysis(dt_mode=self.p.dt_mode)

    def start(self):
        super(TimeDrawDown, self).start()
        if self.p.fund is None:
            self._fundmode = self.strategy.broker.fundmode
        else:
            self._fundmode = self.p.fund

    # Called once the current period is over: sample the value and update
    # the drawdown analysis.
    def on_dt_over(self):
        if not self._fundmode:
            value = self.strategy.broker.getvalue()
        else:
            value = self.strategy.broker.fundvalue

        dt = self.strategy.data.datetime.date(0)
        self.dda.update(dt, value)

    def stop(self):
        # Finalize the last period and publish the per-period statistics.
        self.dda.stop()
        self.rets.update(self.dda.get_analysis())


class TimeDrawDownHistory(bt.TimeFrameAnalyzerBase):
    """Record the history of every drawdown over the whole run.

    The broker value (or the fund value in fund mode) is sampled every time
    ``on_dt_over`` is invoked and fed to a single ``DrawDownAnalysis`` that
    spans the whole run. ``to_dataframe`` returns one row per drawdown.

    Params:

      - ``timeframe`` / ``compression`` (default: ``Days`` / ``1``):
        presumably consumed by ``bt.TimeFrameAnalyzerBase`` to decide when
        ``on_dt_over`` fires — confirm against the base class.
      - ``fund`` (default: ``None``): ``None`` autodetects the broker's fund
        mode; ``True``/``False`` forces fund value / total value.
      - ``csv`` / ``out`` (default: ``False`` / ``None``): when both are set,
        ``stop`` writes ``mdd.csv`` into a sub-directory of ``out`` named
        after the analyzer's parameter values.
      - ``rounding`` (default: ``4``): decimals used for the numeric value
        columns of ``to_dataframe``.
    """

    params = (
        ('timeframe', bt.TimeFrame.Days),
        ('compression', 1),
        ('fund', None),
        ('csv', False),
        ('out', None),
        ('rounding', 4)
    )

    def __init__(self):
        super(TimeDrawDownHistory, self).__init__()
        self.dda = DrawDownAnalysis(dt_mode=bt.TimeFrame.NoTimeFrame)

    def start(self):
        super(TimeDrawDownHistory, self).start()
        if self.p.fund is None:
            self._fundmode = self.strategy.broker.fundmode
        else:
            self._fundmode = self.p.fund

        # Column names of the history rows, stored under the "date" key.
        self.rets["date"] = ["start_date", "low_date", "close_date", "start_value", "low_value", "close_value", "mdd",
                             "mddvalue", "mddlen", "mddlen2"]

    # Called once the current period is over: sample the value and update
    # the drawdown analysis.
    def on_dt_over(self):
        if not self._fundmode:
            value = self.strategy.broker.getvalue()
        else:
            value = self.strategy.broker.fundvalue

        dt = self.strategy.data.datetime.date(0)
        self.dda.update(dt, value)

    def stop(self):
        self.dda.stop()
        if not (self.p.csv and self.p.out):
            return

        # NOTE(review): get_params() is not defined in this file; it is
        # expected to return a mapping of this analyzer's parameters.
        params = self.get_params()
        report_dir = "_".join(map(str, params.values()))
        out = os.path.join(self.p.out, report_dir)
        if not os.path.exists(out):
            os.makedirs(out)

        fpath = os.path.join(out, "mdd.csv")
        self.to_dataframe().to_csv(fpath, index=False)

    def to_dataframe(self):
        """Return the drawdown history as a DataFrame, one row per drawdown.

        The value/ratio columns are rendered as strings with exactly
        ``rounding`` decimals so every row lines up. An empty frame (with the
        column names) is returned when nothing was recorded.
        """
        cols = self.rets["date"]
        # No key is present when no drawdown analysis was ever recorded.
        data = self.dda.get_history().get(NO_TIMEFRAME_DT_KEY, [])
        df = pd.DataFrame(columns=cols, data=data)
        df = df.round(self.p.rounding)
        if len(df) > 0:
            # Fixed-point formatting instead of padding str() with zeros:
            # padding mixed decimal counts across rows and could corrupt
            # scientific notation (e.g. "1e-05" -> "1e-0500").
            decimals = max(int(self.p.rounding), 0)
            for col in ("start_value", "low_value", "close_value", "mdd", "mddvalue"):
                df[col] = df[col].map(lambda v: "%.*f" % (decimals, v))
        return df

